package cn.edu.dgut.css.sai.security.oauth2.client.web;

import cn.edu.dgut.css.sai.security.oauth2.util.IpAddressUtil;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;

import javax.servlet.http.HttpServletRequest;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;


/**
 * Rebuilds the {@link OAuth2AuthorizationRequest} so that it satisfies the request-parameter requirements of the
 * Authorization Code endpoints of WeChat, QQ, Dgut, etc.
 * <p>
 * The request produced by the default {@link DefaultOAuth2AuthorizationRequestResolver} is turned into a builder via
 * {@link OAuth2AuthorizationRequest#from(OAuth2AuthorizationRequest)}. The builder is then used to modify and add
 * parameters, which are ultimately used to build the query parameters of the URL that requests the code.
 *
 * @author Sai
 * @see OAuth2AuthorizationRequestResolver
 * @see DefaultOAuth2AuthorizationRequestResolver
 * @see OAuth2AuthorizationRequestRedirectFilter
 * @since 2.2
 */
public final class DelegateSaiOAuth2AuthorizationRequestResolver implements OAuth2AuthorizationRequestResolver {

    /** Name of the URI template variable holding the registration id in the authorization request path. */
    private static final String REGISTRATION_ID_URI_VARIABLE_NAME = "registrationId";

    private final OAuth2AuthorizationRequestResolver defaultOAuth2AuthorizationRequestResolver;
    private final AntPathRequestMatcher authorizationRequestMatcher;

    /**
     * @param clientRegistrationRepository repository handed to the delegate {@link DefaultOAuth2AuthorizationRequestResolver}
     * @param authorizationRequestBaseUri  base URI of the authorization request endpoint; the registration id is
     *                                     matched as the path segment following it
     */
    public DelegateSaiOAuth2AuthorizationRequestResolver(ClientRegistrationRepository clientRegistrationRepository,
                                                         String authorizationRequestBaseUri) {
        this.defaultOAuth2AuthorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(clientRegistrationRepository, authorizationRequestBaseUri);
        this.authorizationRequestMatcher = new AntPathRequestMatcher(
                authorizationRequestBaseUri + "/{" + REGISTRATION_ID_URI_VARIABLE_NAME + "}");
    }

    /**
     * Resolves the request through the delegate and adapts it for the provider named in the request path.
     *
     * @param request the current request
     * @return the adapted request, or {@code null} when the delegate resolves nothing
     */
    @Override
    public OAuth2AuthorizationRequest resolve(HttpServletRequest request) {
        OAuth2AuthorizationRequest authorizationRequest = defaultOAuth2AuthorizationRequestResolver.resolve(request);
        if (authorizationRequest == null) {
            return null;
        }
        return rebuildOAuth2AuthorizationRequest(authorizationRequest, request, resolveRegistrationId(request));
    }

    /**
     * Resolves the request through the delegate for an explicitly supplied registration id.
     * <p>
     * The registration id is taken from the argument rather than re-parsed from the request path, because this
     * overload does not require the path to match the authorization request endpoint.
     *
     * @param request              the current request
     * @param clientRegistrationId the registration id of the client to authorize
     * @return the adapted request, or {@code null} when the delegate resolves nothing
     */
    @Override
    public OAuth2AuthorizationRequest resolve(HttpServletRequest request, String clientRegistrationId) {
        return rebuildOAuth2AuthorizationRequest(
                defaultOAuth2AuthorizationRequestResolver.resolve(request, clientRegistrationId),
                request,
                clientRegistrationId);
    }

    /**
     * Adapts the delegate's request to the parameter conventions of the given provider.
     *
     * @param oauth2AuthorizationRequest {@link OAuth2AuthorizationRequest} produced by the delegate; may be {@code null}
     * @param request                    {@link HttpServletRequest} used to record the client IP
     * @param registrationId             registration id of the provider; may be {@code null}
     * @return the rebuilt OAuth2AuthorizationRequest, the unchanged input for unknown/absent providers,
     * or {@code null} when the input is {@code null}
     * @since 1.0.5
     */
    private OAuth2AuthorizationRequest rebuildOAuth2AuthorizationRequest(OAuth2AuthorizationRequest oauth2AuthorizationRequest,
                                                                         HttpServletRequest request,
                                                                         String registrationId) {
        if (oauth2AuthorizationRequest == null) {
            return null;
        }
        if (registrationId == null) {
            // No provider could be determined: leave the request as produced by the delegate
            // (same outcome as an unknown provider in the switch below).
            return oauth2AuthorizationRequest;
        }
        OAuth2AuthorizationRequest.Builder builder = OAuth2AuthorizationRequest.from(oauth2AuthorizationRequest);
        Map<String, Object> additionalParameters = new HashMap<>();
        switch (registrationId) {
            case "qq":
                // QQ login needs no extra configuration for now.
                break;
            case "dingding": // https://oapi.dingtalk.com/connect/qrconnect?appid=APPID&response_type=code&scope=snsapi_login&state=STATE&redirect_uri=REDIRECT_URI
            // 2.3: added WeChat in-client (official account) authorization login.
            case "wxmp": // https://developers.weixin.qq.com/doc/offiaccount/OA_Web_Apps/Wechat_webpage_authorization.html
            case "weixin":
                additionalParameters.put("appid", oauth2AuthorizationRequest.getClientId());
                break;
            case "dgut": {
                additionalParameters.put("appid", oauth2AuthorizationRequest.getClientId());
                // The Dgut central authentication system restricts the characters of STATE:
                // optional, max length 128, only digits, letters and the 4 special characters "_" "/" "." "=";
                // an invalid value is silently dropped. Spring's default state generator may emit "-",
                // so "-" is replaced with ".".
                String state = oauth2AuthorizationRequest.getState();
                if (state != null) {
                    builder.state(state.replace("-", "."));
                }
                break;
            }
            default:
                return oauth2AuthorizationRequest;
        }
        // The Dgut central authentication system needs the user's IP when requesting the access_token.
        // The IP is recorded on the authorization request of every supported provider so it can be looked up later.
        additionalParameters.put("ip", IpAddressUtil.getRequestIpAdrress(request));
        // Parameters already present on the delegate's request are copied last, so they win on key collisions.
        additionalParameters.putAll(oauth2AuthorizationRequest.getAdditionalParameters());

        return builder.additionalParameters(additionalParameters).build();
    }

    /**
     * Extracts the registration id from the request path.
     *
     * @param request the current request
     * @return the registration id, or {@code null} if the path does not match the authorization request endpoint
     */
    private String resolveRegistrationId(HttpServletRequest request) {
        if (this.authorizationRequestMatcher.matches(request)) {
            return this.authorizationRequestMatcher.
                    matcher(request).getVariables().get(REGISTRATION_ID_URI_VARIABLE_NAME);
        }
        return null;
    }
}
